package it.uniroma2.exp.dtk;

import it.uniroma2.dtk.dt.GenericDT;
import it.uniroma2.exp.AbstractExperiment;
import it.uniroma2.exp.tools.AvgVarCalculator;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.vector.RandomVectorGenerator;
import it.uniroma2.util.vector.VectorComposer;
import it.uniroma2.util.vector.VectorProvider;

import java.util.Arrays;
import java.util.Random;
import java.util.Vector;

/**
 * This experiment evaluates whether a vector composition function is a good approximation
 * of the ideal vector composition function, according to its properties in Definition 2.
 * <p>
 * The function under test is selected by {@code compositionType} (see
 * {@link #op(double[], double[])}); every property test prints its results to {@code out}.
 *
 * @author Lorenzo Dell'Arciprete
 */
public class FunctionTester extends AbstractExperiment {
	
	/**
	 * Number of samples collected for each averaged statistic. In {@link #testOrthogonality()}
	 * it is also the number of consecutive failed tries after which the test stops.
	 */
	public static final int TRIALS = 10;
	/** Maximum number of vectors (or sums of vectors) chained together in the composition tests. */
	public static final int MAX_COMP = 10;
	/** Size of the initial set of nearly orthogonal vectors built by {@link #testOrthogonality()}. */
	public static final int BASE_SIZE = 100;
	/** Maximum absolute cosine similarity for two vectors to be considered nearly orthogonal. */
	public static final double THRESHOLD = 0.05;
	
	/** Provides the candidate composition functions; created in {@link #runExperiment()}. */
	public VectorComposer vc;
	/** Random vector source, created in {@link #runExperiment()} from {@code vectorSize}. */
	public VectorProvider vp;
	
	// Working vectors shared by the tests; several tests overwrite them with fresh random vectors.
	public double[] x;
	public double[] y;
	public double[] z;
	
	/**
	 * Runs the experiment with vector size 1024, testing the PRODUCT composition type.
	 */
	public static void main(String[] args) {
		FunctionTester ft = new FunctionTester();
		ft.setVectorSize(1024);
		ft.setCompositionType(GenericDT.Types.PRODUCT);
		ft.runAll();
	}

	/**
	 * Creates the vector provider and composer, draws the shared vectors x, y and z,
	 * and then runs every property test in sequence.
	 */
	@Override
	protected void runExperiment() throws Exception {
		vp = new RandomVectorGenerator(vectorSize);
		vc = new VectorComposer(vp);
		x = rand();
		y = rand();
		z = rand();
		testCommutativity();
		testDistributivity();
		testAssociativity();
		testBilinearity();
		testNorm();
		testSimilarity();
		testOrthogonality();
	}
	
	/**
	 * Composes two vectors with the function selected by {@code compositionType}.
	 *
	 * @param v1 left operand
	 * @param v2 right operand
	 * @return the composition v1#v2
	 * @throws Exception if {@code compositionType} is not one of the supported types, or as
	 *         propagated from the underlying {@link VectorComposer} call
	 */
	public double[] op(double[] v1, double[] v2) throws Exception {
		switch (compositionType) {
		case PRODUCT: return vc.shuffledGammaProduct(v1, v2);
		case CONVOLUTION: return vc.shuffledConvolution(v1, v2);
		case SHIFT_PRODUCT: return vc.shiftedGammaProduct(v1, v2);
		case SHIFT_CONVOLUTION: return vc.shiftedConvolution(v1, v2);
		case REV_PRODUCT: return vc.reverseGammaProduct(v1, v2);
		case REV_CONVOLUTION: return vc.reverseConvolution(v1, v2);
		default: throw new Exception("Unknown CDS type:" + compositionType);
		}
	}
	
	/**
	 * Prints x#y and y#x so they can be compared; the ideal function is expected
	 * not to be commutative.
	 */
	public void testCommutativity() throws Exception {
		out.println("Non commutativity");
		out.println("x#y = " + printVector(op(x,y)));
		out.println("y#x = " + printVector(op(y,x)));
	}
	
	/**
	 * Prints both sides of left, right and two-sided distributivity of # over vector sum.
	 */
	public void testDistributivity() throws Exception {
		out.println("Distributivity (1)");
		out.println("(x+y)#z = \t" + printVector(op(vc.sum(x, y),z)));
		out.println("x#z + y#z = \t" + printVector(vc.sum(op(x,z), op(y,z))));
		out.println("Distributivity (2)");
		out.println("z#(x+y) = \t" + printVector(op(z,vc.sum(x, y))));
		out.println("z#x + z#y = \t" + printVector(vc.sum(op(z,x), op(z,y))));
		out.println("Distributivity (3)");
		double[] w = rand();
		out.println("(x+y)#(z+w) = \t\t" + printVector(op(vc.sum(x,y), vc.sum(z,w))));
		out.println("x#z+x#w+y#z+y#w = \t" + printVector(vc.sum(op(x,z), vc.sum(op(x,w), vc.sum(op(y,z), op(y,w))))));
	}
	
	/**
	 * Prints (x#y)#z and x#(y#z) so they can be compared; the ideal function is expected
	 * not to be associative.
	 */
	public void testAssociativity() throws Exception {
		out.println("Non Associativity");
		out.println("(x#y)#z = \t" + printVector(op(op(x, y),z)));
		out.println("x#(y#z) = \t" + printVector(op(x, op(y,z))));
	}
	
	/**
	 * Prints c(x#y), (cx)#y and x#(cy) for a random scalar c in [0, 1).
	 */
	public void testBilinearity() throws Exception {
		out.println("Bilinearity");
		double c = Math.random();
		out.println("c(x#y) = \t" + printVector(ArrayMath.scalardot(c, op(x, y))));
		out.println("cx#y = \t\t" + printVector(op(ArrayMath.scalardot(c, x), y)));
		out.println("x#cy = \t\t" + printVector(op(x, ArrayMath.scalardot(c, y))));
	}
	
	/**
	 * This test performs several kinds of vector compositions and verifies the resulting norms. See Sec. 5.1.
	 * Average and variance are computed over TRIALS values.
	 * A maximum of MAX_COMP vectors are composed for each case.
	 * Note: overwrites the fields x, y and z.
	 */
	public void testNorm() throws Exception {
		AvgVarCalculator[] norms1 = new AvgVarCalculator[MAX_COMP];
		AvgVarCalculator[] norms2 = new AvgVarCalculator[MAX_COMP];
		AvgVarCalculator[] norms3 = new AvgVarCalculator[MAX_COMP];
		for (int i=0; i<MAX_COMP; i++) {
			norms1[i] = new AvgVarCalculator();
			norms2[i] = new AvgVarCalculator();
			norms3[i] = new AvgVarCalculator();
		}
		out.println("Norm");
		for (int i=0; i<TRIALS; i++) {
			// Step j == 0: a single vector, a couple sum and a triplet sum (no composition yet).
			double[] v1 = rand();
			double[] v2 = vc.sum(rand(), rand());
			double[] v3 = vc.sum(rand(), vc.sum(rand(), rand()));
			for (int j=0; j<MAX_COMP; j++) {
				if (j > 0) {
					// Chain one more random vector / couple / triplet onto each composition.
					v1 = op(v1, rand());
					v2 = op(v2, vc.sum(rand(), rand()));
					v3 = op(v3, vc.sum(rand(), vc.sum(rand(), rand())));
				}
				norms1[j].addSample(ArrayMath.norm(v1));
				norms2[j].addSample(ArrayMath.norm(v2));
				norms3[j].addSample(ArrayMath.norm(v3));
			}
		}
		out.println("||x#y#z#...||");
		for (int i=0; i<MAX_COMP; i++) {
			out.println((i+1)+" vectors \t" + norms1[i].getFormattedResult(3));
		}
		// The bracketed value rescales the average by sqrt(2)^(i+1) (resp. sqrt(3)^(i+1) below).
		out.println("||(x+y)#(z+w)#...||");
		for (int i=0; i<MAX_COMP; i++) {
			out.println((i+1)+" couples \t" + 
					String.format("%7.3f (%9.3f) [%.3f]", norms2[i].getAvg(), norms2[i].getVar(), 
							norms2[i].getAvg()/Math.pow(Math.sqrt(2),i+1)));
		}
		out.println("||(x+y+z)#(w+v+u)#...||");
		for (int i=0; i<MAX_COMP; i++) {
			out.println((i+1)+" triplets \t" + 
					String.format("%8.3f (%11.3f) [%6.3f]", norms3[i].getAvg(), norms3[i].getVar(), 
							norms3[i].getAvg()/Math.pow(Math.sqrt(3),i+1)));
		}
		AvgVarCalculator norm = new AvgVarCalculator();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			norm.addSample(ArrayMath.norm(op(vc.sum(x, y),z)));
		}
		out.println("||(x+y)#z|| \t" + String.format("%.3f (%.3f)", norm.getAvg(), norm.getVar()));
		norm.reset();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			norm.addSample(ArrayMath.norm(op(vc.sum(x, y),vc.sum(x, z))));
		}
		out.println("||(x+y)#(x+z)|| \t" + String.format("%.3f (%.3f)", norm.getAvg(), norm.getVar()));
		norm.reset();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			norm.addSample(ArrayMath.norm(vc.sum(op(x, y),op(x, z))));
		}
		out.println("||x#y + x#z|| \t" + String.format("%.3f (%.3f)", norm.getAvg(), norm.getVar()));
		norm.reset();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			norm.addSample(ArrayMath.norm(op(x, op(y, x))));
		}
		// Previously these samples were collected but never reported.
		out.println("||x#(y#x)|| \t" + String.format("%.3f (%.3f)", norm.getAvg(), norm.getVar()));
	}
	
	/**
	 * This test performs several kinds of vector compositions and verifies their similarities. See Sec. 5.1.
	 * Average and variance are computed over TRIALS values.
	 * A maximum of MAX_COMP vectors are composed for each case.
	 * Note: overwrites the fields x, y and z.
	 */
	public void testSimilarity() throws Exception {
		out.println("Similarity");
		AvgVarCalculator simx = new AvgVarCalculator();
		AvgVarCalculator simy = new AvgVarCalculator();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			simx.addSample(ArrayMath.dot(x, op(x, y)));
			simy.addSample(ArrayMath.dot(y, op(x, y)));
		}
		out.println("x ° x#y \t" + simx.getFormattedResult(5));
		out.println("y ° x#y \t" + simy.getFormattedResult(5));
		AvgVarCalculator[] sims = new AvgVarCalculator[MAX_COMP];
		for (int i=0; i<MAX_COMP; i++)
			sims[i] = new AvgVarCalculator();
		// Cosine between an independent vector x and a growing chain y#z#w#...
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			double normx = ArrayMath.norm(x);
			for (int j=0; j<MAX_COMP; j++) {
				y = op(y,rand());
				sims[j].addSample(ArrayMath.dot(x, y)/(normx*ArrayMath.norm(y)));
			}
		}
		out.println("x ° y#z#w#... (normalized)");
		for (int i=0; i<MAX_COMP; i++) {
			out.println((i+1)+" compositions \t" + sims[i].getFormattedResult(5));
		}
		for (int i=0; i<MAX_COMP; i++)
			sims[i] = new AvgVarCalculator();
		// Cosine between x and a growing chain that starts from x itself.
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = Arrays.copyOf(x, x.length);
			double normx = ArrayMath.norm(x);
			for (int j=0; j<MAX_COMP; j++) {
				y = op(y, rand());
				sims[j].addSample(ArrayMath.dot(x, y)/(normx*ArrayMath.norm(y)));
			}
		}
		out.println("x ° x#y#z#... (normalized)");
		for (int i=0; i<MAX_COMP; i++) {
			out.println((i+1)+" compositions \t" + sims[i].getFormattedResult(5));
		}
		for (int i=0; i<MAX_COMP; i++)
			sims[i] = new AvgVarCalculator();
		// Cosine between two chains with different heads and the same (right-appended) tail.
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			for (int j=0; j<MAX_COMP; j++) {
				z = rand();
				x = op(x, z);
				y = op(y, z);
				sims[j].addSample(ArrayMath.dot(x, y)/(ArrayMath.norm(x)*ArrayMath.norm(y)));
			}
		}
		out.println("x#z#w#... ° y#z#w#... (normalized)");
		for (int i=0; i<MAX_COMP; i++) {
			out.println((i+1)+" compositions \t" + sims[i].getFormattedResult(5));
		}
		for (int i=0; i<MAX_COMP; i++)
			sims[i] = new AvgVarCalculator();
		// Cosine between two chains with the same (growing) prefix and different last elements.
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			for (int j=0; j<MAX_COMP; j++) {
				double[] v = op(z, x);
				double[] w = op(z, y);
				sims[j].addSample(ArrayMath.dot(v, w)/(ArrayMath.norm(v)*ArrayMath.norm(w)));
				z = op(z, rand());
			}
		}
		out.println("z#w#...#x ° z#w#...#y (normalized)");
		for (int i=0; i<MAX_COMP; i++) {
			out.println((i+1)+" compositions \t" + sims[i].getFormattedResult(5));
		}
		AvgVarCalculator sim = new AvgVarCalculator();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			sim.addSample(ArrayMath.dot(op(x, op(y, z)), op(z, op(y, x))));
		}
		out.println("x#y#z ° z#y#x \t" + String.format("%.5f (%.5f)", sim.getAvg(), sim.getVar()));
		sim.reset();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			sim.addSample(ArrayMath.dot(op(x, op(y, z)), op(op(x, y), z)));
		}
		out.println("x#(y#z) ° (x#y)#z \t" + String.format("%.5f (%.5f)", sim.getAvg(), sim.getVar()));
		sim.reset();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			sim.addSample(ArrayMath.dot(op(x, op(y, z)), op(x, y)));
		}
		out.println("x#(y#z) ° x#y \t" + String.format("%.5f (%.5f)", sim.getAvg(), sim.getVar()));
		sim.reset();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			z = rand();
			sim.addSample(ArrayMath.dot(op(x, y), op(x, z)));
		}
		out.println("x#y ° x#z \t" + String.format("%.5f (%.5f)", sim.getAvg(), sim.getVar()));
		sim.reset();
		for (int i=0; i<TRIALS; i++) {
			x = rand();
			y = rand();
			sim.addSample(ArrayMath.dot(x, y));
		}
		out.println("x ° y \t" + String.format("%.5f (%.5f)", sim.getAvg(), sim.getVar()));
	}
	
	/**
	 * This test creates a set of BASE_SIZE nearly orthogonal vectors. Then, it starts combining 
	 * couples of random vectors in the set: if the result is still nearly orthogonal wrt each other 
	 * element, it is added to the set. After TRIALS consecutive failed tries, the test ends. The
	 * number of new nearly orthogonal vectors generated is reported, together with the number of
	 * failed tries.
	 * Two vectors are considered nearly orthogonal if the absolute value of their cosine similarity
	 * does not exceed THRESHOLD (so strongly anti-correlated vectors are rejected as well).
	 */
	public void testOrthogonality() throws Exception {
		out.println("Orthogonality");
		Random r = new Random(randomOffset);
		Vector<double[]> vectors = new Vector<double[]>();
		// Build the base set by rejection sampling: keep drawing until the candidate is
		// nearly orthogonal to every vector accepted so far.
		for (int i=0; i<BASE_SIZE; i++) {
			double[] newVector;
			do {
				newVector = rand();
			} while (!isNearlyOrthogonalToAll(newVector, vectors));
			vectors.add(newVector);
		}
		out.println("A set of "+BASE_SIZE+" nearly orthogonal vectors has been built...");
		int retries = 0; // consecutive failed tries since the last accepted vector
		int errors = 0;  // failed tries that preceded an accepted vector
		while (retries < TRIALS) {
			int id1 = (int) (r.nextDouble()*vectors.size());
			int id2 = (int) (r.nextDouble()*vectors.size());
			while (id1 == id2) {
				id2 = (int) (r.nextDouble()*vectors.size());
			}
			double[] newVector = op(vectors.get(id1), vectors.get(id2));
			if (isNearlyOrthogonalToAll(newVector, vectors)) {
				vectors.add(newVector);
				errors += retries;
				retries = 0;
				if (vectors.size() % 50 == 0) {
					out.print(vectors.size() + "...");
				}
				if (vectors.size() % 1000 == 0) {
					out.println();
				}
			} else {
				retries++;
			}
		}
		out.println("\n" + (vectors.size()-BASE_SIZE) + " more valid vectors generated with " + (errors+retries) + " discarded");
	}
	
	/**
	 * Checks whether a candidate vector is nearly orthogonal to every vector of a set.
	 *
	 * @param candidate the vector to check
	 * @param set the vectors to compare against
	 * @return true if the absolute cosine similarity between {@code candidate} and each element
	 *         of {@code set} does not exceed THRESHOLD (trivially true for an empty set)
	 */
	private static boolean isNearlyOrthogonalToAll(double[] candidate, Vector<double[]> set) {
		for (double[] vec : set) {
			if (Math.abs(ArrayMath.cosine(candidate, vec)) > THRESHOLD) {
				return false;
			}
		}
		return true;
	}
	
	/**
	 * Formats the components of a vector as tab-separated values. Once the accumulated text
	 * exceeds 200 characters, the remaining components are replaced by "...".
	 *
	 * @param vec the vector to format
	 * @return the (possibly truncated) textual representation
	 */
	public String printVector(double[] vec) {
		StringBuilder res = new StringBuilder();
		for (double d : vec) {
			if (res.length() > 200) {
				res.append("...");
				break;
			}
			res.append(d).append('\t');
		}
		return res.toString();
	}
	
	/**
	 * Draws a new random vector from {@code vp}.
	 */
	public double[] rand() {
		return vp.generateRandomVector();
	}
}
